from torch.utils import data
from scipy.io import loadmat
from utils.saveNet import *
import numpy as np

def sort_files(flist):
    """Return the file names in ``flist`` ordered by slice, then by phase.

    Each name is parsed once with ``parse_dat_filename``; a name without a
    ``slc`` or ``phs`` field raises ``KeyError``.
    """
    def _slice_phase_key(name):
        fields = parse_dat_filename(name)
        # A tuple key orders by slice first and phase second without packing
        # both into one integer, so phase numbers >= 1000 can never spill into
        # the ordering of the next slice.
        return fields['slc'], fields['phs']

    return sorted(flist, key=_slice_phase_key)

def parse_dat_filename(filename):
    """Extract the integer fields encoded in an underscore-separated filename.

    The last four characters (the extension, e.g. ``.mat``) are dropped and the
    rest is split on ``_``. Each token is checked, in this order, for the
    substrings ``slc``, ``phs``, ``line``, ``lins``, ``cols`` and ``cha``. The
    first tag found wins, and the characters following the first ``len(tag)``
    characters of the token are parsed with ``int`` and stored under that tag.
    Tokens containing no tag are ignored. A non-numeric remainder raises
    ``ValueError``.

    Example: ``'slc3_phs12.mat'`` -> ``{'slc': 3, 'phs': 12}``.
    """
    # Order matters: the first matching tag claims the token.
    known_tags = ('slc', 'phs', 'line', 'lins', 'cols', 'cha')
    parsed = {}
    stem = filename[:-4]
    for token in stem.split('_'):
        match = next((tag for tag in known_tags if tag in token), None)
        if match is None:
            continue
        parsed[match] = int(token[len(match):])
    return parsed

def _select_by_index(items, indices):
    """Return the elements of ``items`` at the positions listed in ``indices``.

    ``indices`` may be any array-like of integral values with any shape (e.g. a
    1xN or Nx1 array as returned by ``loadmat``). It is flattened and cast to
    ``int`` because a plain Python list cannot be indexed by a NumPy array.
    """
    return [items[i] for i in np.asarray(indices).ravel().astype(int)]


def getDatasetGenerators(params):
    """Scan the dataset directories and build training/validation DataLoaders.

    For every root in ``params.dir``, patient folder names are read from
    ``<root>/image`` (they must parse as ints and are processed in numeric
    order). For each patient, the files in ``<root>/kspace/<patient>`` become
    input slices and the files in ``<root>/ref/<patient>`` become ground-truth
    slices. Both lists are ordered with ``sort_files`` so they pair up by
    position.

    The collected path lists are split using the index arrays stored under
    ``'data'`` in ``training_indices.mat`` and ``testing_indices.mat``, which
    are read from the current working directory.

    Side effects on ``params``:
        * ``num_slices_per_patient``, ``input_slices``, ``groundTruth_slices``
          and ``patients`` are reset and populated.
        * ``us_rates`` and ``training_patients_index`` are reset to empty
          lists and not filled here.

    Also reads ``params.batch_size`` and ``params.data_loders_num_workers``,
    plus whatever ``DataGenerator`` reads from ``params``.

    Returns:
        ``(training_DL, validation_DL, params)``. The training loader shuffles;
        the validation loader does not.

    Raises:
        ValueError: if a patient has a different number of k-space and
            reference files, which would otherwise silently misalign all
            subsequent input/target pairs.
    """
    params.num_slices_per_patient = []
    params.input_slices = []
    params.groundTruth_slices = []
    params.us_rates = []
    params.patients = []
    params.training_patients_index = []

    for root in params.dir:
        patient_dirs = sorted(os.listdir(os.path.join(root, 'image')), key=int)
        for patient in patient_dirs:
            params.patients.append(patient)

            kspace_dir = os.path.join(root, 'kspace', patient)
            kspaces = sort_files(os.listdir(kspace_dir))
            params.num_slices_per_patient.append(len(kspaces))
            params.input_slices.extend(os.path.join(kspace_dir, ksp) for ksp in kspaces)

            # Reference images (per the original note: coil-combined,
            # single-channel, complex-valued data stored in .mat files).
            ref_dir = os.path.join(root, 'ref', patient)
            images = sort_files(os.listdir(ref_dir))
            if len(images) != len(kspaces):
                raise ValueError(
                    f"patient {patient!r} under {root!r}: {len(kspaces)} k-space "
                    f"files but {len(images)} reference files"
                )
            params.groundTruth_slices.extend(os.path.join(ref_dir, img) for img in images)

    print('-- Number of Datasets: ' + str(len(params.patients)))

    ### Load Training & Testing indices
    # loadmat returns at-least-2-D arrays; _select_by_index flattens them.
    # NOTE(review): if these index files were written from MATLAB with 1-based
    # indices, a -1 offset is needed here -- confirm how they were produced.
    tr_idx = loadmat('training_indices.mat')['data']
    tst_idx = loadmat('testing_indices.mat')['data']

    training_DS = DataGenerator(input_IDs=_select_by_index(params.input_slices, tr_idx),
                                output_IDs=_select_by_index(params.groundTruth_slices, tr_idx),
                                params=params
                                )

    validation_DS = DataGenerator(input_IDs=_select_by_index(params.input_slices, tst_idx),
                                  output_IDs=_select_by_index(params.groundTruth_slices, tst_idx),
                                  params=params
                                  )

    training_DL = data.DataLoader(training_DS, batch_size=params.batch_size, shuffle=True,
                                  num_workers=params.data_loders_num_workers)

    validation_DL = data.DataLoader(validation_DS, batch_size=params.batch_size, shuffle=False,
                                    num_workers=params.data_loders_num_workers)

    return training_DL, validation_DL, params


def get_moving_window(indx, num_sl, total_num_sl):
    """Return ``num_sl`` consecutive 1-based indices around ``indx``.

    The window starts at ``indx - num_sl // 2``. For odd ``num_sl`` it is
    centred on ``indx``; for even ``num_sl``, ``indx`` sits at offset
    ``num_sl // 2``, so there is one more index before it than after. The
    window is shifted, keeping its length, so that it stays inside
    ``[1, total_num_sl]``.

    Args:
        indx: 1-based index the window is built around.
        num_sl: number of indices in the window.
        total_num_sl: largest valid 1-based index.

    Returns:
        A ``range`` of exactly ``num_sl`` indices. If ``num_sl > total_num_sl``,
        no in-bounds window exists and the returned range extends past one end.
    """
    start = indx - num_sl // 2

    # Clamp at the low end first.
    if start < 1:
        return range(1, num_sl + 1)

    # Last index actually covered is start + num_sl - 1; for odd num_sl this
    # equals indx + num_sl // 2, for even num_sl it is one less.
    if start + num_sl - 1 > total_num_sl:
        return range(total_num_sl - num_sl + 1, total_num_sl + 1)

    return range(start, start + num_sl)


class DataGenerator(data.Dataset):
    '''PyTorch Dataset pairing input .mat files with reference .mat files.

    Items are loaded lazily from disk in ``__getitem__``; ``input_IDs[i]`` is
    paired with ``output_IDs[i]``.
    '''

    def __init__(self, input_IDs, output_IDs, params=None, nums_slices=None, mode='training'):
        '''Initialization.

        Args:
            input_IDs: sequence of paths to input .mat files.
            output_IDs: sequence of paths to reference .mat files, aligned by
                position with ``input_IDs``.
            params: configuration object. ``img_size``, ``n_channels``,
                ``n_spokes`` and ``complex_net`` are read here, and
                ``num_phases`` in ``shuffel_cases``. Required despite the
                ``None`` default.
            nums_slices: stored as-is; not used within this class.
            mode: stored as-is; not used within this class.
        '''
        self.output_IDs = output_IDs
        self.input_IDs = input_IDs
        # Copy of img_size (list or tuple accepted) with a trailing axis of 2
        # appended; presumably real/imaginary channels -- TODO confirm.
        self.dim = list(params.img_size) + [2]
        self.n_channels = params.n_channels
        self.n_spokes = params.n_spokes
        self.nums_slices = nums_slices
        self.complex_net = params.complex_net
        self.mode = mode
        self.params = params

    def __len__(self):
        '''Return the number of samples (one per input ID).'''
        return len(self.input_IDs)

    def shuffel_cases(self):
        '''Shuffle whole cases while keeping each case's entries contiguous.

        IDs are grouped into consecutive blocks of ``params.num_phases``
        entries, the block order is permuted with ``np.random.shuffle``, and
        ``input_IDs`` / ``output_IDs`` are reordered identically so pairs stay
        aligned. If the length is not a multiple of ``num_phases``, the final
        shorter block is kept intact and shuffled like the others.
        '''
        block = self.params.num_phases
        total = len(self.input_IDs)
        starts = list(range(0, total, block))
        np.random.shuffle(starts)
        order = [k for start in starts for k in range(start, min(start + block, total))]
        self.input_IDs = [self.input_IDs[k] for k in order]
        self.output_IDs = [self.output_IDs[k] for k in order]

    def __getitem__(self, index):
        '''Load one sample and return ``(input, reference, input_path)``.'''
        X, y = self.generate_radial_cine_mvw(self.input_IDs[index],
                                             self.output_IDs[index])
        return X, y, self.input_IDs[index]

    def generate_radial_cine_mvw(self, input_ID, output_ID):
        '''Load the ``'data'`` arrays from the input and reference .mat files.

        Axis 2 of each array is moved to the front (``np.moveaxis``), so both
        arrays must have at least 3 dimensions. Per the original comments, the
        input holds k-space data for a moving window of time frames and the
        reference holds fully sampled images -- axis semantics not verified
        here.
        '''
        # Input k-space data (moving window of time frames).
        input_ks_mw = loadmat(input_ID)['data']
        input_ks_mw = np.moveaxis(input_ks_mw, [2], [0])

        # Reference fully sampled images.
        output = loadmat(output_ID)['data']
        output = np.moveaxis(output, [2], [0])

        return input_ks_mw, output
